#!/usr/bin/env python

"""
This utility allows the CIME user to view and modify a case's PE layout.
With this script, a user can:

1) View the PE layout of a case
   ./pelayout
   ./pelayout --format "%C:  %06T/%+H" --header "Comp: Tasks /Th"
2) Attempt to scale the number of tasks used by a case
   ./pelayout --set-ntasks 144
3) Set the number of threads used by a case
   ./pelayout --set-nthrds 2

The --set-ntasks option attempts to scale all components so that the
job will run in the provided number of tasks. For a component using the
maximum number of tasks, this will merely set that component to the new
number. However, for components running in parallel using a portion of
the maximum tasks, --set-ntasks will attempt to scale the tasks
proportionally, changing the value of ROOTPE to maintain the same level
of parallel behavior. If the --set-ntasks algorithm is unable to
automatically find a new layout, it will print an error message
indicating the component(s) it was unable to reset and no changes will
be made to the case.

Interpreted FORMAT sequences are:
%%  a literal %
%C  the component name
%T  the task count for the component
%H  the thread count for the component
%R  the PE root for the component

Standard format extensions, such as field length and padding, are supported.
Python dictionary-format strings are also supported. For instance,
--format "{C:4}" will print the component name padded to 4 characters.

If you encounter problems with this tool or find it is missing any
feature that you need, please open an issue on https://github.com/ESMCI/cime
"""

from standard_script_setup import *

from CIME.case import Case
from CIME.utils import expect, convert_to_string
import sys
import re

# Module-level logger named after this module; modify_ntasks uses it to report
# components whose task count or root PE cannot be rescaled.
# NOTE(review): `logging` is not imported by name in this file -- presumably it
# is supplied by the wildcard import from standard_script_setup; confirm.
logger = logging.getLogger(__name__)

###############################################################################
def parse_command_line(args, description):
###############################################################################
    """
    Build the pelayout argument parser and parse the command line in args.

    description is used as the parser description (shown verbatim by --help).

    Returns the tuple (format, set_ntasks, set_nthrds, header, caseroot).
    header is None when --no-header was given.
    """
    arg_parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    CIME.utils.setup_standard_logging_options(arg_parser)

    # Options that modify the case
    arg_parser.add_argument(
        "--set-ntasks", default=None,
        help="Total number of tasks to set for the case")
    arg_parser.add_argument(
        "--set-nthrds", "--set-nthreads", default=None,
        help="Number of threads to set for all components")

    # Options that control the printed layout
    arg_parser.add_argument(
        "--format", default="%4C: %6T/%6H; %6R",
        help="Format the PE layout items for each component (see below)")
    arg_parser.add_argument(
        "--header", default="Comp  NTASKS  NTHRDS  ROOTPE",
        help="Custom header for PE layout display")
    arg_parser.add_argument(
        "--no-header", default=False, action="store_true",
        help="Do not print any PE layout header")

    arg_parser.add_argument(
        "--caseroot", default=os.getcwd(),
        help="Case directory to reference")

    opts = CIME.utils.parse_args_and_handle_standard_logging_options(args, arg_parser)

    # --no-header overrides any --header value
    header = None if opts.no_header else opts.header

    return opts.format, opts.set_ntasks, opts.set_nthrds, header, opts.caseroot
# End def parse_command_line


###############################################################################
def get_value_as_string(case, var, attribute=None, resolved=False, subgroup=None):
###############################################################################
    """
    Look up var in case and return it converted to a string.

    The value is returned unconverted when it is None or when the case
    reports no (truthy) type information for var.
    """
    var_type = case.get_type_info(var)
    raw = case.get_value(var, attribute=attribute, resolved=resolved, subgroup=subgroup)
    if raw is None or not var_type:
        # Nothing to convert, or no type to convert with
        return raw
    return convert_to_string(raw, var_type, var)

###############################################################################
def format_pelayout(comp, ntasks, nthreads, rootpe, arg_format):
###############################################################################
    """
    Format the PE layout information for one component according to arg_format.

    Recognized sequences in arg_format:
      %%  a literal %
      %C  the component name (comp); an optional run of digits between the
          % and the C is passed through as the format spec, e.g. %4C -> {C:4}
      %T  the task count (ntasks)
      %H  the thread count (nthreads)
      %R  the root PE (rootpe)
    For %T, %H and %R an optional run of digits, '+' and '-' between the % and
    the letter is passed through as the format spec, e.g. %06T -> {T:06}.

    After translation the string is expanded with str.format, so brace fields
    such as {C:4} or {T} may also be written directly in arg_format.

    Returns the formatted string.
    """
    subs = {'C': comp, 'T': ntasks, 'H': nthreads, 'R': rootpe}

    def _to_brace_field(match):
        # Translate one %-sequence into its str.format equivalent
        if match.group(0) == "%%":
            return "%"
        if match.group(1) is not None:
            # %<digits>C
            return "{C:" + match.group(1) + "}"
        # %<spec>T, %<spec>H or %<spec>R
        return "{" + match.group(3) + ":" + match.group(2) + "}"

    # A single pass is used so that "%%" is consumed before the character
    # following it can be mistaken for the start of a field (e.g. "%%C").
    layout_str = re.sub(r"%%|%([0-9]*)C|%([-+0-9]*)([THR])",
                        _to_brace_field, arg_format)
    return layout_str.format(**subs)
# End def format_pelayout

###############################################################################
def print_pelayout(case, ntasks, nthreads, rootpes, arg_format, header):
###############################################################################
    """
    Print one formatted line per component class of case.

    ntasks, nthreads and rootpes are indexed by component class name.
    header is printed first unless it is None; each component line is
    produced by format_pelayout using arg_format.
    """
    components = case.get_values("COMP_CLASSES")

    if header is not None:
        print(header)

    for name in components:
        line = format_pelayout(name, ntasks[name], nthreads[name],
                               rootpes[name], arg_format)
        print(line)

# End def print_pelayout

###############################################################################
def gather_pelayout(case):
###############################################################################
    """
    Collect the PE layout of every component class in case.

    Returns three dictionaries keyed by component class name, in this order:
    task counts (NTASKS_<comp>), thread counts (NTHRDS_<comp>) and root PEs
    (ROOTPE_<comp>). Every value is converted with int().
    """
    task_counts = {}
    thread_counts = {}
    root_pes = {}

    for name in case.get_values("COMP_CLASSES"):
        # Values are read in the order NTASKS, NTHRDS, ROOTPE
        tasks, threads, root = [int(case.get_value(prefix + name))
                                for prefix in ("NTASKS_", "NTHRDS_", "ROOTPE_")]
        task_counts[name] = tasks
        thread_counts[name] = threads
        root_pes[name] = root

    return task_counts, thread_counts, root_pes
# End def gather_pelayout

###############################################################################
def set_nthreads(case, nthreads):
###############################################################################
    """
    Set NTHRDS to nthreads for every component class of case.

    The component class name is passed as the third argument of
    case.set_value for each call.
    """
    for comp_class in case.get_values("COMP_CLASSES"):
        case.set_value("NTHRDS", nthreads, comp_class)
# End def set_nthreads

###############################################################################
def modify_ntasks(case, new_tot_tasks):
###############################################################################
    """
    Rescale every component's NTASKS and ROOTPE so that the case uses
    new_tot_tasks tasks in total.

    The current total is the largest (NTASKS + ROOTPE) over all component
    classes. Each component's NTASKS and ROOTPE are multiplied by
    new_tot_tasks / current total; the change is accepted only when every
    scaled value is an exact integer. If any component cannot be scaled
    exactly, an error is logged for it and expect() is called with a false
    condition before anything is written, so the case is left unchanged.

    Nothing is done when new_tot_tasks already equals the current total.

    new_tot_tasks must be an integer.
    """
    comp_classes = case.get_values("COMP_CLASSES")
    new_tasks = {}
    new_roots = {}

    # First, gather current task and root pe info
    curr_tasks, _, curr_roots = gather_pelayout(case)

    # How many tasks are currently being used? (never less than zero)
    curr_tot_tasks = max([0] + [curr_tasks[comp] + curr_roots[comp]
                                for comp in comp_classes])

    if new_tot_tasks == curr_tot_tasks:
        # Layout already uses the requested number of tasks
        return
    # End if

    # Guard the divisions below with a readable error instead of a
    # ZeroDivisionError (or a silently negative layout)
    expect(new_tot_tasks > 0,
           "pelayout requires a positive number of tasks, got {}".format(new_tot_tasks))
    expect(curr_tot_tasks > 0,
           "pelayout cannot scale a layout with a current total of {} tasks".format(curr_tot_tasks))

    # Compute new task counts and root pes. divmod keeps everything in
    # integer arithmetic (a plain '/' yields floats on Python 3, which would
    # both defeat the validity check and store fractional task counts); a
    # non-zero remainder means the value does not scale exactly.
    mod_ok = True
    for comp in comp_classes:
        new_tasks[comp], task_rem = divmod(curr_tasks[comp] * new_tot_tasks,
                                           curr_tot_tasks)
        new_roots[comp], root_rem = divmod(curr_roots[comp] * new_tot_tasks,
                                           curr_tot_tasks)
        if task_rem != 0:
            logger.error("Task change invalid for {}".format(comp))
            mod_ok = False
        # End if
        if root_rem != 0:
            logger.error("Root PE change invalid for {}".format(comp))
            mod_ok = False
        # End if
    # End for
    expect(mod_ok, "pelayout unable to set ntasks to {}".format(new_tot_tasks))

    # We got this far? Go ahead and change PE layout
    for comp in comp_classes:
        case.set_value("NTASKS_"+comp, new_tasks[comp], comp)
        case.set_value("ROOTPE_"+comp, new_roots[comp], comp)
    # End for
# End def modify_ntasks

###############################################################################
def _main_func(description):
###############################################################################
    """
    Script entry point: parse the command line, apply any requested thread
    or task changes to the case, then print the resulting PE layout.

    description is forwarded to parse_command_line.
    """
    layout_format, new_ntasks, new_nthrds, header, caseroot = \
        parse_command_line(sys.argv, description)

    # Open the case in caseroot; threads are applied before tasks
    with Case(caseroot) as case:
        if new_nthrds is not None:
            set_nthreads(case, new_nthrds)

        if new_ntasks is not None:
            modify_ntasks(case, int(new_ntasks))

        # Re-read the layout so the printout reflects any changes made above
        task_counts, thread_counts, root_pes = gather_pelayout(case)
        print_pelayout(case, task_counts, thread_counts, root_pes,
                       layout_format, header)

# End def _main_func

# Run as a script: the module docstring is passed on as the description
if __name__ == "__main__":
    _main_func(__doc__)
